#include <poll.h>
#include <string.h>
#include <sys/inotify.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <fcntl.h>
#include <iostream>
#include <spawn.h>
#include "utils.hpp"

// Close the wrapped file descriptor when the wrapper is destroyed.
//
// Several functions in this file construct the wrapper directly from the
// result of `open`/`openat` and only check for failure afterwards, so the
// descriptor can be negative here. Skip `close` in that case: the call
// could only fail with EBADF and overwrite `errno`.
AutoCloseFD::~AutoCloseFD() {
  if (fd_ >= 0) {
    close(fd_);
  }
}

// Create a mapping of `size` bytes with `mmap` and record its address and
// length so that the destructor can unmap it. No address hint is given;
// `prot`, `flags`, `fd` and `offset` are forwarded to `mmap` unchanged.
// Throws `ErrorWithErrno` if `mmap` reports failure.
AutoMunmap::AutoMunmap(
  size_t size, int prot, int flags, int fd, off_t offset
) :
  addr_(mmap(nullptr, size, prot, flags, fd, offset)),
  size_(size)
{
  if (MAP_FAILED == addr_) {
    throw ErrorWithErrno("mmap failed.");
  }
}

// Unmap the region recorded by the constructor. The result of `munmap` is
// not checked.
AutoMunmap::~AutoMunmap() {
  static_cast<void>(munmap(addr_, size_));
}

// Synchronously flush the whole mapping with `msync(..., MS_SYNC)`.
// Throws `ErrorWithErrno` if `msync` fails.
void AutoMunmap::sync() const {
  const int rc = msync(get(), size(), MS_SYNC);
  if (rc < 0) {
    throw ErrorWithErrno("msync failed.");
  }
}

// Remove `filename_` (resolved relative to `dirfd_`) when the object is
// destroyed. The result of `unlinkat` is not checked.
AutoUnlink::~AutoUnlink() {
  static_cast<void>(unlinkat(dirfd_, filename_, 0));
}

// Create `filename` (relative to `dirfd`) with permissions `mode` and fill
// it with the first `size` bytes of `buf`. The destructor removes the file
// again. Creation goes through `create_and_write_file`, so any exception
// it throws (for example because the file already exists) propagates from
// this constructor.
// NOTE(review): `filename_` appears to keep the caller's pointer rather
// than a copy -- confirm that callers keep the string alive.
AutoCreateAndDeleteFile::AutoCreateAndDeleteFile(
  int dirfd, const char *filename, const char* buf, size_t size, mode_t mode
)
  : dirfd_(dirfd)
  , filename_(filename)
{
  create_and_write_file(dirfd, filename, buf, size, mode);
}

// Create `filename` (relative to `dirfd`) with permissions `mode` and hand
// the freshly opened descriptor to `initFile`, which is expected to write
// the initial contents. The destructor removes the file again.
// NOTE(review): the descriptor returned by `create_file` is not closed
// here; presumably `initFile` takes ownership of it -- confirm.
AutoCreateAndDeleteFile::AutoCreateAndDeleteFile(
  int dirfd, const char *filename, mode_t mode,
  const std::function<void(int fd)>& initFile
)
  : dirfd_(dirfd)
  , filename_(filename)
{
  const int new_fd = create_file(dirfd, filename, mode);
  initFile(new_fd);
}

// Delete the file that the constructor created. The result of `unlinkat`
// is not checked.
AutoCreateAndDeleteFile::~AutoCreateAndDeleteFile() {
  static_cast<void>(unlinkat(dirfd_, filename_, 0));
}

// Free every entry of `namelist_` and then the array itself. Nothing is
// freed when `n_` is negative (presumably the directory scan failed and
// `namelist_` was never filled in -- confirm against the constructor).
ScanDirAt::~ScanDirAt() {
  if (n_ < 0) {
    return;
  }
  for (int idx = 0; idx != n_; ++idx) {
    free(namelist_[idx]);
  }
  free(namelist_);
}

// Create a TCP socket bound to the loopback address and start listening.
// We will let the OS choose the port number.
//
// Returns the listening socket, which is created with SOCK_CLOEXEC and
// SOCK_NONBLOCK. The caller owns the descriptor.
// Throws `ErrorWithErrno` if any step fails. In that case the socket is
// closed before the exception leaves this function, so no descriptor is
// leaked.
int create_bind_and_listen_tcp() {
  // Create a socket for listening on the port.
  const int sock =
    socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC | SOCK_NONBLOCK, 0);
  if (sock < 0) {
    throw ErrorWithErrno("Failed to create socket.");
  }

  // From here on, every failure has to release the socket. `close` may
  // modify `errno`, so save and restore it around the call to keep the
  // original error code in place when the exception is constructed.
  const auto fail = [sock](const char* msg) {
    const int saved_errno = errno;
    close(sock);
    errno = saved_errno;
    throw ErrorWithErrno(msg);
  };

  // Allow the port to be reused as soon as the program terminates.
  int one = 1;
  if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0) {
    fail("Failed to set SO_REUSEADDR.");
  }

  // Bind the port.
  struct sockaddr_in addr;
  memset(&addr, 0, sizeof(addr));
  addr.sin_family = AF_INET;
  addr.sin_port = 0; // Ask OS to choose a port number
  addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); // localhost

  if (bind(sock, reinterpret_cast<struct sockaddr*>(&addr),
           sizeof(addr)) < 0) {
    fail("Error binding TCP socket to port.");
  }

  // Start listening.
  if (listen(sock, SOMAXCONN) < 0) {
    fail("listen failed.");
  }

  return sock;
}

// Return, in host byte order, the port that `sock` is bound to. This is
// needed because `create_bind_and_listen_tcp` asks the OS to pick the port.
// Throws `ErrorWithErrno` if `getsockname` fails.
uint16_t getportnumber(const int sock) {
  struct sockaddr_in bound;
  socklen_t boundlen = sizeof(bound);
  const int rc = getsockname(
    sock, reinterpret_cast<struct sockaddr*>(&bound), &boundlen
  );
  if (rc < 0) {
    throw ErrorWithErrno("getsockname failed.");
  }
  return ntohs(bound.sin_port);
}

// Register an inotify watch for `filename` with the given event `mask` on
// `inotify_fd`, logging the request to stdout first.
// Returns the watch descriptor. Throws `ErrorWithErrno` if
// `inotify_add_watch` fails.
int add_watch(const int inotify_fd, const char* filename, uint32_t mask) {
  std::cout << "adding watch for " << filename << "\n";

  const int watch = inotify_add_watch(inotify_fd, filename, mask);
  if (watch >= 0) {
    return watch;
  }
  throw ErrorWithErrno(
    std::string("inotify_add_watch of ") + filename + " failed."
  );
}

// Create the symbolic link `linkname` -> `target` and log it to stdout.
// A relative `linkname` is resolved against the directory `newdirfd`.
// Throws `ErrorWithErrno` if `symlinkat` fails.
void createSymlink(
  const char* target, const int newdirfd, const char* linkname
) {
  const int rc = symlinkat(target, newdirfd, linkname);
  if (rc < 0) {
    throw ErrorWithErrno(std::string("Could not create symlink ") + linkname);
  }

  std::cout << "symlink created: " << linkname << " -> " << target << "\n";
}

// Create `filename` (relative to `dirfd`) with permissions `mode` and open
// it write-only with close-on-exec set. O_EXCL makes the call fail if the
// file already exists.
// Returns the open descriptor, which the caller must close. Throws
// `ErrorWithErrno` if the file cannot be created.
int create_file(int dirfd, const char *filename, mode_t mode) {
  const int open_flags = O_CREAT | O_WRONLY | O_TRUNC | O_EXCL | O_CLOEXEC;
  const int new_fd = openat(dirfd, filename, open_flags, mode);
  if (new_fd >= 0) {
    return new_fd;
  }
  throw ErrorWithErrno(std::string("Could not create ") + filename);
}

void touch_file(int dirfd, const char *filename, mode_t mode) {
  const AutoCloseFD fd(
    openat(dirfd, filename, O_CREAT | O_WRONLY | O_CLOEXEC, mode)
  );
  if (fd.get() < 0) {
    throw ErrorWithErrno(std::string("Could not open ") + filename);
  }
  // Update the timestamp.
  if (futimens(fd.get(), 0) < 0) {
    throw ErrorWithErrno(
      std::string("Could not update timestamp of ") + filename
    );
  }
}

void append_file(int dirfd, const char *filename, const char* buf, size_t buflen, mode_t mode) {
  const AutoCloseFD fd(
    openat(dirfd, filename, O_RDWR | O_APPEND | O_CLOEXEC, mode)
  );
  if (fd.get() < 0) {
    throw ErrorWithErrno(std::string("Could not open ") + filename);
  }
  write_or_throw(fd.get(), buf, buflen);
  std::cout << "file appended: " << filename << "\n";
}

// Write `buflen` bytes from `buf` to `fd` with a single `write` call.
// Throws `ErrorWithErrno` if `write` fails and `Error` if fewer than
// `buflen` bytes were written; a short write is not retried.
void write_or_throw(const int fd, const char* buf, size_t buflen) {
  const ssize_t written = write(fd, buf, buflen);
  if (written < 0) {
    throw ErrorWithErrno("write failed");
  }

  const bool complete = (buflen == static_cast<size_t>(written));
  if (!complete) {
    throw Error("incomplete write");
  }
}

// Create a file and write the contents of `buf` to it. This function
// will throw an exception if the file already exists.
void create_and_write_file(
  int dirfd, const char *filename, const char* buf, size_t buflen, mode_t mode
) {
  const AutoCloseFD fd(create_file(dirfd, filename, mode));
  if (buflen > 0) {
    write_or_throw(fd.get(), buf, buflen);
  }
  std::cout << "file created: " << filename << "\n";
}

// Utility for writing enormous strings to a file. Repeatedly writes `msg` to
// the file until exactly `totallen` bytes have been written. (The final copy
// of `msg` might get truncated.)
void write_repeated_buffer(
  const int fd, const char* msg, size_t msglen, size_t totallen
) {
  // Create a large block with 4096 copies of the message, to reduce the number
  // of calls to `write`.
  std::string block;
  block.reserve(msglen * 4096);
  for (size_t i = 0; i < 4096; i++) {
    block.append(msg, msglen);
  }

  const char* blockptr = block.c_str();
  size_t blocksize = block.size();
  size_t pos = 0;
  while (1) {
    pos += blocksize;
    if (pos <= totallen) {
      write_or_throw(fd, blockptr, blocksize);
    } else {
      // The block is too big. So we need to rewind and write out a
      // smaller number of bytes.
      pos -= blocksize;
      write_or_throw(fd, blockptr, totallen - pos);
      // We are done.
      return;
    }
  }
}

// Use `poll` to wait for the file descriptor to be readable.
void fd_wait_for_read(const int inotify_fd) {
  const nfds_t nfds = 1;
  struct pollfd pollfds[1] = {0};
  pollfds[0].fd = inotify_fd;
  pollfds[0].events = POLLIN;

  while (1) {
    const int poll_num = poll(pollfds, nfds, -1);
    if (unlikely(poll_num < 0)) {
      const int err = errno;
      if (err == EINTR) {
        continue;
      }
      throw ErrorWithErrno("poll failed");
    }

    if (likely(poll_num > 0)) {
      if (likely(pollfds[0].revents & POLLIN)) {
        break;
      }
    }
  }
}

// Read and discard input from `fd` until `read` returns zero or a negative
// value. (We use this to reset inotify after it has reported an event.)
// Errors from `read` are not reported; they simply end the loop.
void drain_fd(const int fd) {
  char scratch[4096];
  for (;;) {
    const ssize_t got = read(fd, scratch, sizeof(scratch));
    if (got <= 0) {
      return;
    }
  }
}

// Send signal `sig` to the child process `cpid` and then block in
// `waitpid` until it has been reaped. The exit status is discarded.
// Throws `ErrorWithErrno` if either `kill` or `waitpid` fails.
void kill_and_wait(const pid_t cpid, const int sig) {
  const int kill_rc = kill(cpid, sig);
  if (kill_rc < 0) {
    throw ErrorWithErrno("kill() failed");
  }

  const pid_t reaped = waitpid(cpid, nullptr, 0);
  if (reaped < 0) {
    throw ErrorWithErrno("waitpid() failed");
  }
}

// Keep spawning child processes until we get one with the desired PID.
//
// Each attempt starts `/bin/sleep 10s` with `posix_spawn` and passes the
// new PID to `isDesired`. If the predicate accepts it, that PID is returned
// and the child is left running (it is not waited for here). Otherwise the
// child is terminated with SIGTERM, reaped, and the next attempt starts.
//
// Throws `ErrorWithErrno` if `posix_spawn` fails and `Error` if none of the
// `numtries` attempts produced an acceptable PID. Exceptions from
// `kill_and_wait` propagate unchanged.
pid_t fork_child_with_pid(
  size_t numtries, const std::function<bool(pid_t)>& isDesired
) {
  char prog[] = "/bin/sleep";
  char arg[] = "10s";
  char *const argv[3] = {prog, arg, nullptr};

  // It should take less than 64k iterations to get the desired PID,
  // but we'll set a limit so that we don't loop forever.
  for (size_t i = 0; i < numtries; i++) {
    pid_t cpid = 0;
    const int r = posix_spawn(&cpid, "/bin/sleep", 0, 0, argv, 0);
    if (r != 0) {
      // `posix_spawn` returns the error number instead of setting `errno`,
      // so store it there to make the exception report the real cause
      // rather than a stale value.
      errno = r;
      throw ErrorWithErrno("posix_spawn failed.");
    }
    if (isDesired(cpid)) {
      // Successfully forked a child with the correct PID.
      std::cout << "Successfully started child process "
                << cpid << " after " << i << " iterations.\n";
      return cpid;
    }
    // The child PID is wrong, so kill it and try again.
    kill_and_wait(cpid, SIGTERM);
  }

  // Failed to create a process with the desired PID.
  throw Error("Failed to create a process with the desired PID.");
}

// Search `/proc/*/cmdline` to find the PID of a running program.
std::vector<pid_t> search_pids(const char *cmdline, size_t cmdline_len) {
  AutoCloseFD procdir_fd(open("/proc", O_PATH | O_CLOEXEC));
  if (procdir_fd.get() < 0) {
    throw ErrorWithErrno("Could not open /proc.");
  }
  ScanDirAt scanDir(procdir_fd.get());

  const int n = scanDir.size();
  std::vector<pid_t> result;
  for (int i = 0; i < n; i++) {
    const char* subdir_name = scanDir.get(i);
    AutoCloseFD subdir_fd(
      openat(procdir_fd.get(), subdir_name, O_PATH | O_CLOEXEC)
    );
    if (procdir_fd.get() < 0) {
      continue;
    }
    AutoCloseFD cmdline_fd(
      openat(subdir_fd.get(), "cmdline", O_RDONLY | O_CLOEXEC)
    );
    if (cmdline_fd.get() < 0) {
      continue;
    }

    // Check if the command line matches.
    char buf[0x1000];
    ssize_t r = read(cmdline_fd.get(), buf, sizeof(buf));
    if (r < 0 || static_cast<size_t>(r) < cmdline_len) {
      continue;
    }
    if (memcmp(buf, cmdline, cmdline_len) == 0) {
      // The name of the sub-directory is the PID.
      result.push_back(atoi(subdir_name));
    }
  }
  return result;
}

// Look up the PID of the single process whose command line starts with the
// first `cmdline_len` bytes of `cmdline` (see `search_pids`).
// Returns that PID, or -1 if there is no match or more than one match.
pid_t search_pid(const char *cmdline, size_t cmdline_len) {
  const std::vector<pid_t> matches = search_pids(cmdline, cmdline_len);
  const bool unique = (matches.size() == 1);
  return unique ? matches.front() : -1;
}
